!! Copyright (C) 2009,2010,2011,2012  Marco Restelli
!!
!! This file is part of:
!!   LDGH -- Local Hybridizable Discontinuous Galerkin toolkit
!!
!! LDGH is free software: you can redistribute it and/or modify it
!! under the terms of the GNU General Public License as published by
!! the Free Software Foundation, either version 3 of the License, or
!! (at your option) any later version.
!!
!! LDGH is distributed in the hope that it will be useful, but WITHOUT
!! ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
!! or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
!! License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with LDGH. If not, see <http://www.gnu.org/licenses/>.
!!
!! author: Marco Restelli                   <marco.restelli@gmail.com>


module mod_linal
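!General comments: provides some basic linear algebra operators
!(eigenvalues, determinant, matrix inversion, linear systems, sorting).
!Linear systems are solved with direct methods: Gaussian elimination
!(LU factorization, with or without pivoting) or Cholesky factorization.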
!-----------------------------------------------------------------------

 use mod_messages, only: &
   mod_messages_initialized, &
   error,   &
   warning, &
   info

 use mod_kinds, only: &
   mod_kinds_initialized, &
   wp

 use complex_eigen, only: &
   dceigv

!-----------------------------------------------------------------------

 implicit none

!-----------------------------------------------------------------------

! Module interface
!
! Everything is private by default (see the "private" statement below);
! only the names listed here are visible to the users of the module.

 public :: &
   mod_linal_constructor, &
   mod_linal_destructor,  &
   mod_linal_initialized, &
   ! eigenvalues
   eig,        &
   ! matrix inversion
   det,        & ! determinant
   invmat,     & ! Gauss fact., double pivoting
   invmat_nop, & ! Gauss fact., no pivoting
   invmat_chol,& ! Cholesky fact.
   ! linear systems
   linsys,     &
   linsys_chol,&
   lu,         &
   ! order values
   sort,       &
   fsort

 private

!-----------------------------------------------------------------------

! Module types and parameters

 ! public members

 ! private members

! Module variables

 ! public members
 ! Set to .true. by mod_linal_constructor and back to .false. by
 ! mod_linal_destructor; being "protected", it is read-only outside of
 ! this module.
 logical, protected ::               &
   mod_linal_initialized = .false.

 ! private members
 ! Module name passed to the error/warning procedures.
 character(len=*), parameter :: &
   this_mod_name = 'mod_linal'

 ! Generic LU factorization: the five argument form (a,l,u,p,q)
 ! resolves to lu_p (row and column pivoting), the three argument form
 ! (a,l,u) resolves to lu_nop (no pivoting).
 interface lu
   module procedure lu_p, lu_nop
 end interface lu

 ! Generic in-place sort: sort_r for real(wp) values, sort_i for
 ! integer values (sort_i can also return the permutation parity).
 interface sort
   module procedure sort_r, sort_i
 end interface sort
 ! Function form of the integer sort: returns a sorted copy.
 interface fsort
   module procedure sort_i_f
 end interface fsort

!-----------------------------------------------------------------------

contains

!-----------------------------------------------------------------------

 subroutine mod_linal_constructor()
 ! Flag the module as initialized.
 ! An error is raised (through mod_messages) if mod_messages or
 ! mod_kinds have not been initialized; a warning is issued if this
 ! module is already flagged as initialized.

  character(len=*), parameter :: &
    this_sub_name = 'constructor'
  logical :: deps_ready

   !Consistency checks ---------------------------
   deps_ready = mod_messages_initialized .and. mod_kinds_initialized
   if(.not.deps_ready) call error(this_sub_name,this_mod_name, &
                'Not all the required modules are initialized.')
   if(mod_linal_initialized) call warning(this_sub_name,this_mod_name, &
                  'Module is already initialized.')
   !----------------------------------------------

   mod_linal_initialized = .true.
 end subroutine mod_linal_constructor

!-----------------------------------------------------------------------
 
 subroutine mod_linal_destructor()
 ! Flag the module as not initialized.
 ! Calling the destructor while the module is not initialized is
 ! reported through the error procedure of mod_messages, mirroring the
 ! consistency checks done in mod_linal_constructor.
  character(len=*), parameter :: &
    this_sub_name = 'destructor'
   
   !Consistency checks ---------------------------
   if(mod_linal_initialized.eqv..false.) then
     call error(this_sub_name,this_mod_name, &
                'This module is not initialized.')
   endif
   !----------------------------------------------

   mod_linal_initialized = .false.
 end subroutine mod_linal_destructor

!-----------------------------------------------------------------------
 
 pure subroutine eig(a,wr,wi,zr,zi,ierr)
 ! Eigenvalues and eigenvectors of a real square matrix.
 ! This is a thin wrapper around dceigv (module complex_eigen): the
 ! matrix is passed as a complex matrix whose imaginary part is zero,
 ! and balancing is always requested.
 !
 !  a      matrix; declared intent(inout) because it is handed over to
 !         dceigv (NOTE(review): presumably it can be overwritten --
 !         confirm against dceigv)
 !  wr,wi  eigenvalues (presumably real and imaginary parts)
 !  zr,zi  eigenvectors (presumably real and imaginary parts)
 !  ierr   status flag, set by dceigv
   real(wp), intent(inout) :: a(:,:)
   real(wp), intent(out)   :: wr(:), wi(:)
   real(wp), intent(out)   :: zr(:,:), zi(:,:)
   integer, intent(out)    :: ierr

   logical, parameter :: balance = .true.
   integer :: order
   real(wp) :: a_imag(size(a,1),size(a,2))

   ! the input matrix is real: no imaginary part
   a_imag = 0.0_wp
   order = size(a,1)

   call dceigv(balance,a,a_imag,order,wr,wi,zr,zi,ierr)

 end subroutine eig
 
!-----------------------------------------------------------------------
 
 pure recursive function det(m) result(d)
 ! Determinant of the square matrix m, computed with the cofactor
 ! expansion along the first row. Each call on an n x n matrix makes n
 ! recursive calls on (n-1) x (n-1) minors, so that the cost grows
 ! factorially with n. The determinant of a 0 x 0 matrix is 1.
  real(wp), intent(in) :: m(:,:)
  real(wp) :: d

  integer :: order, col, k, keep(size(m,1)-1)
  real(wp) :: sgn

   order = size(m,1)
   select case(order)
    case(0)
     d = 1.0_wp
    case(1)
     d = m(1,1)
    case default
     ! keep lists the columns of the current minor: start with all
     ! the columns but the first one
     keep = (/ (k, k=2,order) /)
     d = 0.0_wp
     sgn = 1.0_wp ! alternating sign of the cofactors
     do col=1,order
       d = d + sgn*m(1,col)*det( m(2:order,keep) )
       sgn = -sgn
       ! next minor: put back column col, thereby dropping col+1
       if(col.lt.order) keep(col) = col
     enddo
   end select

 end function det
 
!-----------------------------------------------------------------------
 
 pure subroutine invmat(a,ai)
 ! Inverse of a square matrix, direct method with row and column
 ! pivoting. From the factorization
 !  p a q = l u
 ! one has  a^{-1} = q u^{-1} l^{-1} p,  so that each column of the
 ! inverse is obtained with one forward and one backward substitution
 ! using the corresponding column of p as right hand side, followed
 ! by the multiplication by q.
  real(wp), intent(in) :: a(:,:)
  real(wp), intent(out) :: ai(:,:)

  integer :: col
  real(wp), dimension(size(a,1),size(a,2)) :: lower, upper, prow, pcol
  real(wp), dimension(size(a,1)) :: work

   call lu(a,lower,upper,prow,pcol)

   columns_do: do col=1,size(a,2)
     ! lower * work = prow(:,col)
     call forwsubst(lower,prow(:,col),work)
     ! upper * ai(:,col) = work
     call backsubst(upper,work,ai(:,col))
   enddo columns_do

   ! undo the column permutation
   ai = matmul(pcol,ai)

 end subroutine invmat
 
!-----------------------------------------------------------------------

 pure subroutine invmat_nop(a,ai)
 ! Inverse of a square matrix, direct method without pivoting: a is
 ! factored as
 !  a = l u
 ! and then  a^{-1} = u^{-1} l^{-1}  is computed in two sweeps.
 ! Since there is no pivoting, a vanishing diagonal entry in l or u
 ! leads to a division by zero.
  real(wp), intent(in) :: a(:,:)
  real(wp), intent(out) :: ai(:,:)

  integer :: row, col, n
  real(wp) :: inv_piv
  real(wp), dimension(size(a,1),size(a,2)) :: lo, up, linv

   n = size(a,1)
   call lu(a,lo,up)

   ! The two triangular solves are done here for all the right hand
   ! sides at once, instead of calling forwsubst and backsubst once
   ! per column.

   ! First sweep: linv = lo^{-1}, lower triangular, built row by row
   linv = 0.0_wp
   do row=1,n
     linv(row,row) = 1.0_wp/lo(row,row)
     do col=1,row-1
       linv(row,col) = -linv(row,row)*dot_product( lo(row,col:row-1) , &
                                                   linv(col:row-1,col) )
     enddo
   enddo

   ! Second sweep: up * ai = linv, from the last row upwards
   do row=n,1,-1
     inv_piv = 1.0_wp/up(row,row)
     do col=1,n
       if(col.le.row) then
         ai(row,col) = inv_piv*( linv(row,col)                     &
                 - dot_product( up(row,row+1:n) , ai(row+1:n,col) ) )
       else
         ! here linv(row,col) is zero
         ai(row,col) = -inv_piv*dot_product( up(row,row+1:n) , &
                                             ai(row+1:n,col) )
       endif
     enddo
   enddo

 end subroutine invmat_nop
 
!-----------------------------------------------------------------------
 
 pure subroutine invmat_chol(a,ai)
 ! Compute a^{-1} with the Cholesky factorization. a must be
 ! symmetric positive definite for this to make sense: no check is
 ! made, and only the lower triangular part of a is used.
 ! The matrix a is factored as
 !  a = l l^T
 ! and then the triangular systems are solved for all the columns at
 ! once, taking advantage of the fact that a^{-1} is also symmetric.
 ! An empty (0 x 0) matrix is accepted and leaves ai untouched.
  real(wp), intent(in) :: a(:,:)
  real(wp), intent(out) :: ai(:,:)

  integer :: i, j, n
  real(wp), dimension(size(a,1),size(a,2)) :: l, y

   n = size(a,1)

   !1) Cholesky factorization
   ! The loop runs up to n, so that the last pivot needs no special
   ! treatment: for i=n the column section and the inner loop are
   ! empty. This also avoids touching l(0,0) when n=0.
   l = a
   do i=1,n
     l(i,i) = sqrt(l(i,i))
     l(i+1:n,i) = l(i+1:n,i)/l(i,i)
     do j=i+1,n
       l(j:n,j) = l(j:n,j) - l(j:n,i)*l(j,i)
     enddo
   enddo
   ! notice that only the lower triangular part of l is meaningful

   !2) solution of the triangular systems
   ! y = l^{-1}, lower triangular, built row by row
   y = 0.0_wp
   do i=1,n ! loop on the rows
     y(i,i) = 1.0_wp/l(i,i)
     do j=1,i-1
       y(i,j) = -y(i,i)*dot_product( l(i,j:i-1) , y(j:i-1,j)  )
     enddo
   enddo

   ! l^T ai = y, from the last row upwards; only the lower triangular
   ! part is computed, the upper one is filled by symmetry (notice
   ! that y(i,i) = 1/l(i,i))
   do i=n,1,-1 ! loop on the rows
     do j=1,i-1
       ai(i,j) = y(i,i)*(y(i,j) - dot_product( l(i+1:n,i) , ai(i+1:n,j) ))
       ai(j,i) = ai(i,j)
     enddo
     ai(i,i) = y(i,i)*(y(i,i) - dot_product( l(i+1:n,i) , ai(i+1:n,i) ))
   enddo

 end subroutine invmat_chol
 
!-----------------------------------------------------------------------
 
 pure subroutine linsys(a,b,x)
 ! Solve the linear system  a x = b  with a direct method. With the
 ! factorization (row and column pivoting)
 !   p a q = l u
 ! the solution is obtained in three steps:
 !  l y = p b   (forward substitution)
 !  u z = y     (backward substitution)
 !  x   = q z   (column permutation)
  real(wp), intent(in) :: a(:,:), b(:)
  real(wp), intent(out) :: x(:)

  real(wp), dimension(size(a,1),size(a,2)) :: lower, upper, prow, pcol
  real(wp), dimension(size(a,1)) :: rhs, y, z

   call lu(a,lower,upper,prow,pcol)

   rhs = matmul(prow,b) ! permuted right hand side
   call forwsubst(lower,rhs,y)
   call backsubst(upper,y,z)

   x = matmul(pcol,z)

 end subroutine linsys
 
!-----------------------------------------------------------------------

 pure subroutine linsys_chol(a,b,x)
 ! Solve the linear system  a x = b  with the Cholesky method (only
 ! for symmetric, positive definite matrices: no check is made, and
 ! only the lower triangular part of a is used). The matrix is
 ! factored as
 !  a = l l^T
 ! and then the two triangular systems are solved:
 !  l   y = b
 !  l^T x = y
 ! An empty (0 x 0) system is accepted.
  real(wp), intent(in) :: a(:,:), b(:)
  real(wp), intent(out) :: x(:)

  integer :: i, j, n
  real(wp) :: l(size(a,1),size(a,2)), y(size(a,1))

   n = size(a,1)

   !1) Cholesky factorization
   ! The loop runs up to n, so that the last pivot needs no special
   ! treatment: for i=n the column section and the inner loop are
   ! empty. This also avoids touching l(0,0) when n=0.
   l = a
   do i=1,n
     l(i,i) = sqrt(l(i,i))
     l(i+1:n,i) = l(i+1:n,i)/l(i,i)
     do j=i+1,n
       l(j:n,j) = l(j:n,j) - l(j:n,i)*l(j,i)
     enddo
   enddo
   ! notice that only the lower triangular part of l is meaningful

   !2) triangular systems
   ! backsubst only reads the upper triangular part of its matrix
   ! argument, which for transpose(l) is the meaningful part of l
   call forwsubst(l,b,y)
   call backsubst(transpose(l),y,x)

 end subroutine linsys_chol

!-----------------------------------------------------------------------
 
 pure subroutine lu_p(aa,l,u,p,q)
 ! LU factorization with row and column pivoting:
 !  p aa q = l u
 ! All the matrices must be square and of the same dimension.
 !
 ! At each step k the working matrix is permuted so that the pivot
 ! selected by pivot() is moved to position (k,k), and then column k
 ! is eliminated below the diagonal. The permutations are accumulated
 ! in p and q, the inverses of the elimination matrices in elim_inv.
  real(wp), intent(in) :: aa(:,:) ! matrix to factorize
  real(wp), intent(out) :: l(:,:), u(:,:), & ! l,u factors
                           p(:,:), q(:,:)    ! permutations

  integer :: step, n
  real(wp), dimension(size(aa,1),size(aa,2)) :: &
    work, elim_inv, prow, pcol, elim, elim_i

   n = size(aa,1)

   ! the factorization works on a local copy
   work = aa

   ! p, q and elim_inv start as identity matrices
   call eye(p)
   q        = p
   elim_inv = p

   steps_do: do step=1,n-1
     ! select the pivot and move it to position (step,step)
     call pivot(work,step,prow,pcol)
     work = matmul(prow,matmul(work,pcol))
     p = matmul(prow,p)
     q = matmul(q,pcol)
     ! gaussian elimination of column step
     call mgauss(work,step,elim,elim_i)
     work = matmul(elim,work)
     elim_inv = matmul(elim_inv,matmul(prow,elim_i))
   enddo steps_do

   ! upper factor: clean the lower triangular part of the result
   u = work
   call triu(u)
   ! lower factor
   l = matmul(p,elim_inv)

 end subroutine lu_p
 
!-----------------------------------------------------------------------
 
 pure subroutine lu_nop(aa,l,u)
 ! LU factorization without any pivoting:
 !  aa = l u
 ! All the matrices must be square and of the same dimension.
 ! Since the diagonal entries are used as pivots as they are, a zero
 ! pivot leads to a division by zero in mgauss.
  real(wp), intent(in) :: aa(:,:) ! matrix to factorize
  real(wp), intent(out) :: l(:,:), u(:,:)

  integer :: step, n
  real(wp), dimension(size(aa,1),size(aa,2)) :: work, elim, elim_i

   n = size(aa,1)

   ! the factorization works on a local copy
   work = aa

   ! l accumulates the inverses of the elimination matrices
   call eye(l)

   steps_do: do step=1,n-1
     ! gaussian elimination of column step
     call mgauss(work,step,elim,elim_i)
     work = matmul(elim,work)
     l = matmul(l,elim_i)
   enddo steps_do

   ! upper factor: clean the lower triangular part of the result
   u = work
   call triu(u)

 end subroutine lu_nop
 
!-----------------------------------------------------------------------
 
 pure subroutine sort_r(x,ind)
 ! Sort the vector x in ascending order, applying the same sequence
 ! of exchanges to the index vector ind. By passing
 !  ind = (/ (i, i=1,size(x)) /)
 ! one obtains an index vector such that
 !  x_sort = x(ind)
 ! This version uses a bubble sort algorithm; equal values are never
 ! exchanged.
  real(wp), intent(inout) :: x(:)
  integer, intent(inout) :: ind(:)

  logical :: swapped
  integer :: last, j

   ! Algorithm "Bubble Sort": after each pass the largest value of
   ! x(1:last+1) sits in position last+1, so the next pass is shorter
   pass_do: do last=size(x)-1,1,-1
     swapped = .false.
     do j=1,last
       if(x(j).gt.x(j+1)) then
         ! exchange two neighbours (in x and in ind)
         x(j:j+1)   = x(j+1:j:-1)
         ind(j:j+1) = ind(j+1:j:-1)
         swapped = .true.
       endif
     enddo
     ! a pass without exchanges means that x is sorted
     if(.not.swapped) exit pass_do
   enddo pass_do

 end subroutine sort_r
 
!-----------------------------------------------------------------------
 
 pure subroutine sort_i(x,ind,p)
 ! Same as sort_r, but for integer values: x is sorted in ascending
 ! order with a bubble sort and the same exchanges are applied to ind.
 ! If present, p returns the parity of the permutation: it starts
 ! from 1 and changes sign at each exchange.
  integer, intent(inout) :: x(:)
  integer, intent(inout) :: ind(:)
  integer, intent(out), optional :: p ! parity

  logical :: swapped
  integer :: last, j

   if(present(p)) p = 1

   ! Algorithm "Bubble Sort": after each pass the largest value of
   ! x(1:last+1) sits in position last+1, so the next pass is shorter
   pass_do: do last=size(x)-1,1,-1
     swapped = .false.
     do j=1,last
       if(x(j).gt.x(j+1)) then
         ! exchange two neighbours (in x and in ind)
         x(j:j+1)   = x(j+1:j:-1)
         ind(j:j+1) = ind(j+1:j:-1)
         if(present(p)) p = -p
         swapped = .true.
       endif
     enddo
     ! a pass without exchanges means that x is sorted
     if(.not.swapped) exit pass_do
   enddo pass_do

 end subroutine sort_i

!-----------------------------------------------------------------------

 pure function sort_i_f(x) result(y)
 ! Function version of sort_i: returns a copy of x sorted in
 ! ascending order (bubble sort), leaving x unchanged. No index
 ! vector nor parity is provided.
  integer, intent(in) :: x(:)
  integer :: y(size(x))

  logical :: swapped
  integer :: last, j

   y = x

   ! Algorithm "Bubble Sort": after each pass the largest value of
   ! y(1:last+1) sits in position last+1, so the next pass is shorter
   pass_do: do last=size(x)-1,1,-1
     swapped = .false.
     do j=1,last
       if(y(j).gt.y(j+1)) then
         ! exchange two neighbours
         y(j:j+1) = y(j+1:j:-1)
         swapped = .true.
       endif
     enddo
     ! a pass without exchanges means that y is sorted
     if(.not.swapped) exit pass_do
   enddo pass_do

 end function sort_i_f

!-----------------------------------------------------------------------
 
 pure subroutine eye(i)
 ! Set i to the identity matrix; i must be square.
  real(wp), intent(out) :: i(:,:)

  integer :: r, c

   ! one on the diagonal, zero elsewhere (column by column)
   do c=1,size(i,2)
     do r=1,size(i,1)
       i(r,c) = merge(1.0_wp, 0.0_wp, r.eq.c)
     enddo
   enddo

 end subroutine eye
 
!-----------------------------------------------------------------------
 
 pure subroutine triu(u)
 ! Set to zero the strictly lower triangular part of u, leaving the
 ! diagonal and the upper triangular part untouched.
  real(wp), intent(inout) :: u(:,:)

  integer :: col, nrows

   nrows = size(u,1)
   ! proceed by columns: below the diagonal of column col
   do col=1,nrows-1
     u(col+1:nrows,col) = 0.0_wp
   enddo

 end subroutine triu
 
!-----------------------------------------------------------------------
 
 pure subroutine forwsubst(l,b,x)
 ! Forward substitution: solve the lower triangular system  l x = b.
 ! Only the diagonal and the lower triangular part of l are read;
 ! the number of unknowns is taken from the size of b.
  real(wp), intent(in) :: l(:,:), b(:)
  real(wp), intent(out) :: x(:)

  integer :: row
  real(wp) :: known

   do row=1,size(b)
     ! contribution of the unknowns which have already been computed
     known = sum(l(row,1:row-1)*x(1:row-1))
     x(row) = (b(row) - known)/l(row,row)
   enddo

 end subroutine forwsubst
 
!-----------------------------------------------------------------------
 
 pure subroutine backsubst(u,b,x)
 ! Backward substitution: solve the upper triangular system  u x = b.
 ! Only the diagonal and the upper triangular part of u are read;
 ! the number of unknowns is taken from the size of b.
  real(wp), intent(in) :: u(:,:), b(:)
  real(wp), intent(out) :: x(:)

  integer :: row, last
  real(wp) :: known

   last = size(b)
   do row=last,1,-1
     ! contribution of the unknowns which have already been computed
     known = sum(u(row,row+1:last)*x(row+1:last))
     x(row) = (b(row) - known)/u(row,row)
   enddo

 end subroutine backsubst
 
!-----------------------------------------------------------------------
 
 pure subroutine pivot(a,k,pk,qk)
 ! Compute the k-th permutation matrices pk and qk, where the pivot
 ! element is the largest (in absolute value) element of the
 ! submatrix a(k:n,k:n). Both pk and qk are identity matrices with two
 ! rows exchanged: rows k and prow for pk, rows k and pcol for qk,
 ! where (prow,pcol) is the position of the pivot in a. If the pivot
 ! is already in row (column) k, pk (qk) is the identity.
  integer, intent(in) :: k
  real(wp), intent(in) :: a(:,:)
  real(wp), intent(out) :: pk(:,:), qk(:,:)

  integer :: n, loc(2), prow, pcol

   n = size(a,1)
   ! maxloc works with indexes relative to the submatrix: shift them
   loc  = maxloc(abs(a(k:n,k:n)))
   prow = loc(1) + k - 1
   pcol = loc(2) + k - 1

   ! row permutation; the diagonal entries are cleared before setting
   ! the off-diagonal ones, so that prow = k yields the identity
   call eye(pk)
   pk(k,k)       = 0.0_wp
   pk(prow,prow) = 0.0_wp
   pk(k,prow)    = 1.0_wp
   pk(prow,k)    = 1.0_wp

   ! column permutation, same construction
   call eye(qk)
   qk(k,k)       = 0.0_wp
   qk(pcol,pcol) = 0.0_wp
   qk(k,pcol)    = 1.0_wp
   qk(pcol,k)    = 1.0_wp

 end subroutine pivot
 
!-----------------------------------------------------------------------
 
 pure subroutine mgauss(a,k,mk,mki)
 ! Compute the k-th matrices mk and mki of the Gaussian elimination.
 ! mk is the identity matrix with the multipliers -a(i,k)/a(k,k),
 ! i>k, in column k below the diagonal; mki is computed as 2*I - mk,
 ! i.e. the same matrix with the sign of the multipliers changed,
 ! which is the inverse of mk.
 ! No check is made on the pivot a(k,k).
  integer, intent(in) :: k
  real(wp), intent(in) :: a(:,:)
  real(wp), intent(out) :: mk(:,:), mki(:,:)

  integer :: n

   n = size(mk,1)

   ! elimination matrix: multipliers of column k
   call eye(mk)
   mk(k+1:n,k) = -a(k+1:n,k)/a(k,k)

   ! inverse of the elimination matrix
   call eye(mki)
   mki = 2.0_wp*mki - mk

 end subroutine mgauss
 
!-----------------------------------------------------------------------

end module mod_linal

